# --- Environment setup (Colab): GPU check and third-party installs ---
!nvidia-smi
!pip install kornia
!pip install lpips
!pip install git+https://github.com/ChristophReich1996/Involution
# Import necessary libraries
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler
from torchvision.datasets import CIFAR10
from torchvision import datasets, transforms
from torch.optim import *
import os
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import math
import cv2
import glob
import copy
from sklearn.model_selection import train_test_split
#import wandb
from torchsummary import summary
from skimage.feature import hog
from tqdm import tqdm as tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 42
# NOTE(review): several imports below duplicate earlier ones (torch, numpy,
# glob) and some above appear unused in this notebook (CIFAR10, hog,
# train_test_split, numba) — candidates for cleanup.
from numba import jit, cuda
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, Dataset
from PIL import Image
import glob
import albumentations as A
from albumentations.pytorch import ToTensor
# https://kornia.readthedocs.io/en/latest/losses.html
from kornia.losses import ssim, psnr, ssim_loss, psnr_loss
import lpips
import torch
from involution import Involution2d
#involution = Involution2d(in_channels=32, out_channels=64)
#output = involution(torch.rand(1, 32, 128, 128))
# Mount google drive to colab
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# --- Dataset preparation: copy and unzip SOTS (test) and RESIDE (train) ---
!ls -lrt /content/drive/MyDrive/CV_Project/Dataset
!cp /content/drive/MyDrive/CV_Project/Dataset/SOTS.zip ./
!mkdir RESIDE
!cp /content/drive/MyDrive/CV_Project/Dataset/archive.zip ./RESIDE/
%cd /content/
!unzip -qq SOTS.zip
%cd /content/RESIDE
!unzip -qq archive.zip
%cd /content/
!ls -lrt ./RESIDE
!ls -lrt ./SOTS/indoor/
!ls -lrt ./SOTS/
!ls -lrt /content/SOTS/outdoor/
# Delete problematic outdoor hazy images — presumably ones without a
# matching ground-truth file; verify against the gt/ folder.
!rm -rf /content/SOTS/outdoor/hazy/0051_0.95_0.12.jpg
!rm -rf /content/SOTS/outdoor/hazy/0076_1_0.16.jpg
!rm -rf /content/SOTS/outdoor/hazy/0086_0.95_0.12.jpg
!rm -rf /content/SOTS/outdoor/hazy/0108_1_0.2.jpg
!rm -rf /content/SOTS/outdoor/hazy/0253_1_0.16.jpg
!rm -rf /content/SOTS/outdoor/hazy/0287_0.95_0.08.jpg
!rm -rf /content/SOTS/outdoor/hazy/0330_0.8_0.08.jpg
!rm -rf /content/SOTS/outdoor/hazy/0320_0.9_0.08.jpg
!ls -lrt
class RESIDEDataset(Dataset):
    """Paired hazy/clear image dataset for RESIDE (train) and SOTS (test).

    Expects `path` to contain a 'hazy/' folder plus either 'clear/' (train
    split) or 'gt/' (test split). A hazy file named '<id>_<beta>_<A>.jpg'
    is matched to the ground-truth image '<id>.png'.
    Images are loaded with OpenCV and converted BGR -> RGB.
    """
    def __init__(self, path, train=True, transform=None):
        self.path = path
        self.transform = transform
        self.train = train
        self.images_hazy = sorted(glob.glob(self.path + 'hazy/' + '*'))
        # Ground truth lives under 'clear/' for training data and 'gt/' for
        # the SOTS test splits.
        gt_dir = 'clear/' if self.train else 'gt/'
        self.images_clear = sorted(glob.glob(self.path + gt_dir + '*'))
        self.clear_base_path = self.path + gt_dir
    def __getitem__(self, index):
        image_hazy_path = self.images_hazy[index]
        hazy = cv2.imread(image_hazy_path)
        hazy = cv2.cvtColor(hazy, cv2.COLOR_BGR2RGB)
        # Hazy files are '<id>_<beta>_<A>.jpg'; the ground truth is '<id>.png'.
        # os.path.basename replaces the original non-portable split('/').
        clear_name = os.path.basename(image_hazy_path).split('_')[0] + '.png'
        clear = cv2.imread(self.clear_base_path + clear_name)
        clear = cv2.cvtColor(clear, cv2.COLOR_BGR2RGB)
        if self.transform:
            transformed = self.transform(image=hazy, mask=clear)
            hazy = transformed['image']
            # Albumentations leaves the mask HWC; convert to CHW tensor layout.
            clear = torch.squeeze(transformed['mask']).permute(2, 0, 1)
        # BUGFIX: the original returned names that were undefined whenever
        # transform was None (NameError); now the raw pair is returned instead.
        return hazy, clear
    def __len__(self):
        return len(self.images_hazy)
# --- Data pipeline: augmentation + normalization for train, resize-only for test ---
batch_size = 64
train_transform = A.Compose(
    [
        A.Resize(height=128, width=128),
        A.HorizontalFlip(),
        A.Rotate(30),
        # Dataset-specific channel statistics (not the ImageNet defaults).
        A.Normalize(mean=(0.64, 0.6, 0.58), std=(0.14, 0.15, 0.152)),
        ToTensor(),  # NOTE(review): deprecated in newer albumentations; ToTensorV2 is the replacement
    ])
test_transform = A.Compose(
    [
        A.Resize(height=128, width=128),
        A.Normalize(mean=(0.64, 0.6, 0.58), std=(0.14, 0.15, 0.152)),
        ToTensor(),
    ])
trainset = RESIDEDataset(path='/content/RESIDE/', train=True, transform=train_transform)
testset_id = RESIDEDataset(path='/content/SOTS/indoor/', train=False, transform=test_transform)
testset_od = RESIDEDataset(path='/content/SOTS/outdoor/', train=False, transform=test_transform)
# BUGFIX: shuffle the training set — the original used shuffle=False, so every
# epoch iterated the dataset in the identical (sorted-filename) order.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader_id = DataLoader(testset_id, batch_size=batch_size, shuffle=False, num_workers=2)
testloader_od = DataLoader(testset_od, batch_size=batch_size, shuffle=False, num_workers=2)
# --- Sanity-check one training batch and the LPIPS perceptual metric ---
dataiter = iter(trainloader)
# BUGFIX: the `.next()` method was removed from iterators in Python 3 (and
# from DataLoader iterators in torch >= 1.13); use the builtin next().
images, masks = next(dataiter)
print(type(images))
print(type(masks))
print(images.shape)
print(masks.shape)
import lpips
loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
# loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss;
# appeared to crash with batch_size of 256
# NOTE(review): LPIPS expects inputs scaled to [-1, 1]; the custom Normalize
# above only roughly approximates that range — confirm.
d = loss_fn_alex(images, masks)
torch.mean(d)
images[0].shape
plt.imshow(images[0].permute(1, 2, 0))
# Display 10 hazy training images from the sampled batch.
# NOTE(review): these tensors are normalized, so imshow clips values outside
# [0, 1] and the rendered colors are distorted.
figure = plt.figure()
num_of_images = 10
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(images[index].numpy().transpose(1, 2, 0))
# Display the 10 corresponding clear (ground-truth) images.
figure = plt.figure()
num_of_images = 10
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(masks[index].numpy().transpose(1, 2, 0))
def train_epoch(model, trainloader, criterion, optimizer, lr_scheduler, phase='train'):
    """Run one training epoch over `trainloader`.

    Loss is a fixed composite of LPIPS (perceptual), SSIM loss and PSNR loss;
    `criterion` and `lr_scheduler` are accepted for interface compatibility
    but unused here.

    Returns:
        (mean batch loss, per-sample mean SSIM, per-sample mean PSNR)
    """
    model.train()
    model = model.to(device)
    # PERF BUGFIX: the original constructed lpips.LPIPS(...) (building a full
    # AlexNet) inside the batch loop, on every single iteration. Build it once
    # per epoch. `.to(device)` also works on CPU-only machines, unlike the
    # original's hard-coded `.cuda()`.
    loss_fn_alex = lpips.LPIPS(net='alex', verbose=False).to(device)
    epoch_loss = 0.
    psnr_score_running = 0.
    ssim_score_running = 0.
    batch_num = 0
    samples_num = 0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        pred_masks = model(inputs)
        # Composite loss: perceptual similarity + structure + fidelity.
        # NOTE(review): kornia's psnr_loss is the negated PSNR, so this term
        # pushes PSNR up; the 0.6/0.1/0.3 weights come from the original author.
        loss_dist = loss_fn_alex(labels, pred_masks)
        ssim_loss_batch = ssim_loss(labels, pred_masks, 11)
        psnr_loss_batch = psnr_loss(labels, pred_masks, 1)
        composite_loss = 0.6 * loss_dist + 0.1 * ssim_loss_batch + 0.3 * psnr_loss_batch
        loss = torch.mean(composite_loss)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().cpu().item()
        # Track evaluation metrics (not the losses) for reporting.
        ssim_score = torch.mean(ssim(labels, pred_masks, 11))
        psnr_score = psnr(labels, pred_masks, 1)
        ssim_score_running += ssim_score.detach().cpu().item() * len(labels)
        psnr_score_running += psnr_score.detach().cpu().item() * len(labels)
        batch_num += 1
        samples_num += len(labels)
    return epoch_loss / batch_num, ssim_score_running / samples_num, psnr_score_running / samples_num
def test_epoch(model, testloader, criterion, optimizer, lr_scheduler, phase='test'):
    """Evaluate `model` on `testloader` without gradient tracking.

    Uses the same composite LPIPS + SSIM + PSNR loss as train_epoch;
    `criterion`, `optimizer` and `lr_scheduler` are unused but kept for
    interface compatibility.

    Returns:
        (mean batch loss, per-sample mean SSIM, per-sample mean PSNR)
    """
    model.eval()
    model = model.to(device)
    # PERF BUGFIX: build the LPIPS (AlexNet) network once per call, not once
    # per batch as the original did; `.to(device)` instead of `.cuda()`.
    loss_fn_alex = lpips.LPIPS(net='alex', verbose=False).to(device)
    epoch_loss = 0.
    batch_num = 0
    samples_num = 0
    psnr_score_running = 0.
    ssim_score_running = 0.
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(testloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            pred_masks = model(inputs)
            loss_dist = loss_fn_alex(labels, pred_masks)
            ssim_loss_batch = ssim_loss(labels, pred_masks, 11)
            psnr_loss_batch = psnr_loss(labels, pred_masks, 1)
            composite_loss = 0.6 * loss_dist + 0.1 * ssim_loss_batch + 0.3 * psnr_loss_batch
            loss = torch.mean(composite_loss)
            epoch_loss += loss.detach().cpu().item()
            ssim_score = torch.mean(ssim(labels, pred_masks, 11))
            psnr_score = psnr(labels, pred_masks, 1)
            ssim_score_running += ssim_score.detach().cpu().item() * len(labels)
            psnr_score_running += psnr_score.detach().cpu().item() * len(labels)
            batch_num += 1
            samples_num += len(labels)
    return epoch_loss / batch_num, ssim_score_running / samples_num, psnr_score_running / samples_num
def train_model(model, train_loader, test_loader_id, test_loader_od, criterion, optimizer, lr_scheduler, epochs):
    """Train for `epochs` epochs, evaluating on the SOTS indoor and outdoor
    test sets after every epoch and saving a checkpoint per epoch.

    Returns the per-epoch history:
        (train_losses, train_ssims, train_psnrs,
         test_losses_id, test_ssims_id, test_psnrs_id,
         test_losses_od, test_ssims_od, test_psnrs_od)
    """
    train_losses, train_ssims, train_psnrs = [], [], []
    test_losses_id, test_ssims_id, test_psnrs_id = [], [], []
    test_losses_od, test_ssims_od, test_psnrs_od = [], [], []
    ckpt_dir = '/content/drive/MyDrive/CV_Project/CheckPoints_U/r3'
    # Robustness: avoid a crash at save time if the folder does not exist yet.
    os.makedirs(ckpt_dir, exist_ok=True)
    for epoch in range(epochs):
        print('='*15, f'Epoch: {epoch}')
        train_loss, train_ssim, train_psnr = train_epoch(model, train_loader, criterion, optimizer, lr_scheduler, phase='train')
        test_loss_id, test_ssim_id, test_psnr_id = test_epoch(model, test_loader_id, criterion, optimizer, lr_scheduler, phase='test')
        test_loss_od, test_ssim_od, test_psnr_od = test_epoch(model, test_loader_od, criterion, optimizer, lr_scheduler, phase='test')
        # BUGFIX: the ReduceLROnPlateau scheduler built at top level was never
        # stepped (the call was commented out); drive it with the indoor test
        # loss as the monitored metric.
        if lr_scheduler is not None:
            lr_scheduler.step(test_loss_id)
        print()
        print(f'Train loss: {train_loss}, Train SSIM: {train_ssim}, Train PSNR: {train_psnr}')
        print(f'Test loss SOTS Indoor: {test_loss_id}, Test SSIM SOTS Indoor: {test_ssim_id}, Test PSNR SOTS Indoor: {test_psnr_id}')
        print(f'Test loss SOTS Outdoor: {test_loss_od}, Test SSIM SOTS Outdoor: {test_ssim_od}, Test PSNR SOTS Outdoor: {test_psnr_od}')
        print()
        train_losses.append(train_loss)
        train_ssims.append(train_ssim)
        train_psnrs.append(train_psnr)
        test_losses_id.append(test_loss_id)
        test_ssims_id.append(test_ssim_id)
        test_psnrs_id.append(test_psnr_id)
        test_losses_od.append(test_loss_od)
        test_ssims_od.append(test_ssim_od)
        test_psnrs_od.append(test_psnr_od)
        torch.save({'epoch': epoch, 'model': model.state_dict()}, f'{ckpt_dir}/unet-{epoch}.pt')
    return train_losses,train_ssims,train_psnrs,test_losses_id,test_ssims_id,test_psnrs_id,test_losses_od,test_ssims_od,test_psnrs_od
class Block_en(nn.Module):
    """Encoder block: 3x3 involution then 3x3 convolution, ReLU after each.

    NOTE(review): `conv1` and `inonv2` are constructed but never used by
    forward(); they are deliberately kept so the module's parameter layout
    (and hence any saved state_dict) stays unchanged.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.inonv1 = Involution2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3), padding=(1, 1))
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.inonv2 = Involution2d(in_channels=out_ch, out_channels=out_ch, kernel_size=(3, 3), padding=(1, 1))
    def forward(self, x):
        out = self.relu(self.inonv1(x))
        out = self.relu(self.conv2(out))
        return out
class Block_de(nn.Module):
    """Decoder block: two 3x3 convolutions, each followed by ReLU."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        return out
class Encoder(nn.Module):
    """U-Net encoder: a stack of Block_en stages with 2x2 max-pooling
    between them; returns the pre-pool feature map of every stage for use
    as skip connections."""
    def __init__(self, chs=(3, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            Block_en(c_in, c_out) for c_in, c_out in zip(chs, chs[1:])
        )
        self.pool = nn.MaxPool2d(2)
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)  # saved before pooling, for the decoder's skips
            x = self.pool(x)
        return ftrs
class Decoder(nn.Module):
    """U-Net decoder: repeatedly upsample, concatenate the (center-cropped)
    encoder skip feature, then apply a Block_de."""
    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        self.upconvs = nn.ModuleList(
            nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in zip(chs, chs[1:])
        )
        self.dec_blocks = nn.ModuleList(
            Block_de(c_in, c_out) for c_in, c_out in zip(chs, chs[1:])
        )
    def forward(self, x, encoder_features):
        for upconv, block, skip in zip(self.upconvs, self.dec_blocks, encoder_features):
            x = upconv(x)
            skip = self.crop(skip, x)
            x = block(torch.cat([x, skip], dim=1))
        return x
    def crop(self, enc_ftrs, x):
        # Crop the encoder feature map to x's spatial size so the two can be
        # concatenated channel-wise.
        _, _, H, W = x.shape
        return torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
class InvolutionUNet(nn.Module):
    """U-Net-style dehazing network with involution-based encoder blocks.

    When retain_dim is True the output is interpolated to out_sz; otherwise
    the spatial size is whatever the decoder produces.
    """
    def __init__(self, enc_chs=(3,64,128,256,512), dec_chs=(512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        self.out_sz = out_sz
    def forward(self, x):
        # Deepest encoder feature seeds the decoder; the rest are skips,
        # ordered deepest-first.
        deepest, *skips = self.encoder(x)[::-1]
        out = self.head(self.decoder(deepest, skips))
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
# --- Instantiate the dehazing U-Net plus optimizer and LR scheduler ---
model_unet = InvolutionUNet(enc_chs=(3, 64, 128, 256), dec_chs=(256, 128, 64), num_class=3, retain_dim=False, out_sz=(128, 128))
summary(model_unet.cuda(), (3, 128, 128))
# Smoke test: one forward pass on a dummy batch.
# (Renamed from `input`, which shadowed the builtin.)
sample_input = torch.randn((1, 3, 64, 64), requires_grad=True)
output = model_unet(sample_input.cuda())
output.shape
criterion = None  # the composite loss is hard-coded inside train_epoch/test_epoch
optimizer = torch.optim.SGD(model_unet.parameters(), lr=1e-2, weight_decay=0.01, momentum=0.9)
# ReduceLROnPlateau: train_model steps it with the indoor test loss.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
epochs = 15
train_losses,train_ssims,train_psnrs,test_losses_id,test_ssims_id,test_psnrs_id,test_losses_od,test_ssims_od,test_psnrs_od = train_model(model_unet, trainloader, testloader_id, testloader_od, criterion, optimizer, lr_scheduler, epochs)
# Display-only versions of the test sets: same resize, but no Normalize so
# imshow renders natural colors.
test_transform_display = A.Compose(
    [
        #A.CenterCrop(height=224, width=224),
        A.Resize(height=128, width=128),
        #A.Normalize(mean=(0.64, 0.6, 0.58),std=(0.14,0.15, 0.152)),
        ToTensor(),
    ])
testset_id_display = RESIDEDataset(path = '/content/SOTS/indoor/', train = False, transform = test_transform_display)
testset_od_display = RESIDEDataset(path = '/content/SOTS/outdoor/', train = False, transform = test_transform_display)
testloader_id_display = torch.utils.data.DataLoader(testset_id_display, batch_size=batch_size, shuffle=False, num_workers=2)
testloader_od_display = torch.utils.data.DataLoader(testset_od_display, batch_size=batch_size, shuffle=False, num_workers=2)
# --- Evaluate the epoch-14 checkpoint on SOTS indoor and grab a display batch ---
best_model = InvolutionUNet(enc_chs=(3, 64, 128, 256), dec_chs=(256, 128, 64), num_class=3, retain_dim=False, out_sz=(128, 128))
# BUGFIX: map_location lets a GPU-saved checkpoint load on CPU-only machines.
best_ckp = torch.load('/content/drive/MyDrive/CV_Project/CheckPoints_U/r3/unet-14.pt', map_location=device)
best_model.load_state_dict(best_ckp['model'])
test_loss_id, test_ssim_id, test_psnr_id = test_epoch(best_model, testloader_id, criterion, optimizer, lr_scheduler, phase='test')
print()
print(f'Test loss SOTS Indoor: {test_loss_id}, Test SSIM SOTS Indoor: {test_ssim_id}, Test PSNR SOTS Indoor: {test_psnr_id}')
print()
dataiter = iter(testloader_id_display)
# BUGFIX: `.next()` was removed in Python 3; use the builtin next().
images_did, masks_did = next(dataiter)
print(type(images_did))
print(type(masks_did))
print(images_did.shape)
print(masks_did.shape)
# Display 5 hazy indoor test images (every 10th image in the batch, so the
# samples span different scenes).
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(images_did[index*10].numpy().transpose(1, 2, 0))
plt.suptitle('Indoor Hazy Images', fontsize=90)
plt.subplots_adjust(top=0.95)
# Display the 5 corresponding clear (ground-truth) images.
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(masks_did[index*10].numpy().transpose(1, 2, 0))
plt.suptitle('Indoor Clear Images', fontsize=90)
plt.subplots_adjust(top=0.95)
# --- Run the model on a normalized indoor test batch ---
dataiter = iter(testloader_id)
# BUGFIX: `.next()` -> next() (Python 3 iterator protocol).
images_fid, masks_fid = next(dataiter)
print(type(images_fid))
print(type(masks_fid))
print(images_fid.shape)
print(masks_fid.shape)
best_model.eval()
with torch.no_grad():
    images_fid = images_fid.to(device)
    best_model = best_model.to(device)
    pred_masks_fid = best_model(images_fid)
# Display the 5 corresponding dehazed predictions.
# NOTE(review): the model output is in normalized space; imshow clips float
# values outside [0, 1], so rendered colors are only approximate.
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(pred_masks_fid[index*10].cpu().numpy().transpose(1, 2, 0))
plt.suptitle('Indoor Dehazed Images', fontsize=90)
plt.subplots_adjust(top=0.95)
# --- Evaluate the epoch-13 checkpoint on SOTS outdoor and grab a display batch ---
best_model = InvolutionUNet(enc_chs=(3, 64, 128, 256), dec_chs=(256, 128, 64), num_class=3, retain_dim=False, out_sz=(128, 128))
# BUGFIX: map_location lets a GPU-saved checkpoint load on CPU-only machines.
best_ckp = torch.load('/content/drive/MyDrive/CV_Project/CheckPoints_U/r3/unet-13.pt', map_location=device)
best_model.load_state_dict(best_ckp['model'])
test_loss_od, test_ssim_od, test_psnr_od = test_epoch(best_model, testloader_od, criterion, optimizer, lr_scheduler, phase='test')
print()
print(f'Test loss SOTS Outdoor: {test_loss_od}, Test SSIM SOTS Outdoor: {test_ssim_od}, Test PSNR SOTS Outdoor: {test_psnr_od}')
print()
dataiter = iter(testloader_od_display)
# BUGFIX: `.next()` -> next() (Python 3 iterator protocol).
images_dod, masks_dod = next(dataiter)
print(type(images_dod))
print(type(masks_dod))
print(images_dod.shape)
print(masks_dod.shape)
# Display 5 hazy outdoor test images from the display batch.
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(images_dod[index].numpy().transpose(1, 2, 0))
plt.suptitle('Outdoor Hazy Images', fontsize=90)
plt.subplots_adjust(top=0.95)
# Display the 5 corresponding clear (ground-truth) images.
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(masks_dod[index].numpy().transpose(1, 2, 0))
plt.suptitle('Outdoor Clear Images', fontsize=90)
plt.subplots_adjust(top=0.95)
# --- Run the model on a normalized outdoor test batch ---
dataiter = iter(testloader_od)
# BUGFIX: `.next()` -> next() (Python 3 iterator protocol).
images_fod, masks_fod = next(dataiter)
print(type(images_fod))
print(type(masks_fod))
print(images_fod.shape)
print(masks_fod.shape)
best_model.eval()
with torch.no_grad():
    images_fod = images_fod.to(device)
    best_model = best_model.to(device)
    pred_masks_fod = best_model(images_fod)
# Display the 5 corresponding dehazed predictions.
# NOTE(review): the model output is in normalized space; imshow clips float
# values outside [0, 1], so rendered colors are only approximate.
figure = plt.figure()
num_of_images = 5
plt.figure(figsize=(100,50))
for index in range(0, num_of_images):
    plt.subplot(2, 5, index+1)
    plt.axis('off')
    plt.imshow(pred_masks_fod[index].cpu().numpy().transpose(1, 2, 0))
plt.suptitle('Outdoor Dehazed Images', fontsize=90)
plt.subplots_adjust(top=0.95)